"""
Kernel density estimation models (Parzen window and Gaussian kernels), adapted from
https://github.com/EugenHotaj/pytorch-generative/blob/master/pytorch_generative/models/kde.py
"""

import abc
import math

import numpy as np
import torch
from torch import nn
# from pytorch_generative.models import base


class Kernel(abc.ABC, nn.Module):
    """Abstract interface shared by every KDE kernel.

    Concrete kernels implement ``forward`` (log-density evaluation of test
    points against a set of kernel centres) and ``sample``.
    """

    def __init__(self, bandwidth):
        """Stores the kernel bandwidth.

        Args:
            bandwidth: The kernel's (band)width.
        """
        super().__init__()
        self.bandwidth = bandwidth

    def _diffs(self, test_Xs, train_Xs):
        """Returns every pairwise difference ``test_Xs[i] - train_Xs[j]``.

        Axis 0 of the result indexes test_Xs and axis 1 indexes train_Xs; any
        trailing (feature) axes are broadcast against each other.
        """
        # Insert a singleton axis so the two sample axes broadcast to a grid.
        return test_Xs[:, None] - train_Xs[None]

    @abc.abstractmethod
    def forward(self, test_Xs, train_Xs):
        """Returns log p(x) for every x in test_Xs, given the train_Xs centres."""

    @abc.abstractmethod
    def sample(self, train_Xs):
        """Draws samples from the kernel placed on each of the train_Xs."""


class ParzenWindowKernel(Kernel):
    """Implementation of the Parzen window kernel.

    Each training point contributes a uniform density of
    ``1 / bandwidth ** dim`` inside an axis-aligned hypercube of side
    ``bandwidth`` centred on it, and zero outside.
    """

    def forward(self, test_Xs, train_Xs):
        """Computes log p(x) for each x in test_Xs under the Parzen window KDE.

        Args:
            test_Xs: Points to evaluate; axis 0 indexes the points.
            train_Xs: Kernel centres; axis 0 indexes the centres and the
                remaining axes must broadcast against those of test_Xs.

        Returns:
            Tensor with one log density per test point (``-inf`` for points
            that fall outside every window).
        """
        abs_diffs = torch.abs(self._diffs(test_Xs, train_Xs))
        scaled = abs_diffs / self.bandwidth
        # A point is inside a window only if every coordinate lies within
        # bandwidth / 2 of the window centre.
        within = scaled <= 0.5
        if within.dim() > 2:
            within = within.flatten(start_dim=2).all(dim=-1)
        # Number of scalar features per point (1 for 1-D data).
        dim = int(np.prod(abs_diffs.shape[2:]))
        # log p(x) = log(mean(inside)) - dim * log(bandwidth). Working in log
        # space avoids overflowing 1 / bandwidth ** dim for high-dim data,
        # which previously produced inf / NaN densities.
        log_volume = dim * math.log(self.bandwidth)
        return torch.log(within.to(scaled.dtype).mean(dim=1)) - log_volume

    def sample(self, train_Xs):
        """Perturbs each training point with uniform noise in its window.

        Args:
            train_Xs: Kernel centres to sample around.

        Returns:
            ``train_Xs`` plus noise drawn uniformly from
            ``[-bandwidth / 2, bandwidth / 2)`` per coordinate.
        """
        # Create the noise on the same device as the data to avoid a
        # CPU/GPU mismatch when train_Xs lives on an accelerator.
        noise = (torch.rand(train_Xs.shape, device=train_Xs.device) - 0.5) * self.bandwidth
        return train_Xs + noise


class GaussianKernel(Kernel):
    """Implementation of the isotropic Gaussian kernel.

    Each training point is the mean of a Gaussian whose standard deviation
    along every coordinate is ``bandwidth``; the KDE is their equal-weight
    mixture.
    """

    def forward(self, test_Xs, train_Xs):
        """Computes log p(x) for each x in test_Xs under the Gaussian KDE.

        Args:
            test_Xs: Points to evaluate; axis 0 indexes the points.
            train_Xs: Kernel centres; axis 0 indexes the centres and the
                remaining axes are the feature axes.

        Returns:
            Tensor with one log density per test point, on the same device
            as the inputs.

        Raises:
            ValueError: If train_Xs contains no samples.
        """
        n = train_Xs.shape[0]
        if n == 0:
            raise ValueError("train_Xs must contain at least one sample.")
        # Number of scalar features per sample (1 for 1-D data).
        d = int(np.prod(train_Xs.shape[1:]))
        # Log normaliser of an equal-weight mixture of n isotropic Gaussians:
        # log((2 * pi) ** (d / 2) * h ** d * n). Computed as a Python float so
        # it follows the inputs' device instead of being forced onto CUDA.
        log_norm = (
            0.5 * d * math.log(2 * math.pi)
            + d * math.log(self.bandwidth)
            + math.log(n)
        )
        diffs = self._diffs(test_Xs, train_Xs) / self.bandwidth
        sq_dists = diffs.pow(2)
        # Reduce over all feature axes; 1-D data has none left to reduce.
        if sq_dists.dim() > 2:
            sq_dists = sq_dists.flatten(start_dim=2).sum(dim=-1)
        return torch.logsumexp(-0.5 * sq_dists - log_norm, dim=-1)

    def sample(self, train_Xs):
        """Perturbs each training point with Gaussian noise.

        Args:
            train_Xs: Kernel centres to sample around.

        Returns:
            ``train_Xs`` plus zero-mean noise with std ``bandwidth``.
        """
        # Create the noise on the data's device to avoid CPU/GPU mismatches.
        noise = torch.randn(train_Xs.shape, device=train_Xs.device) * self.bandwidth
        return train_Xs + noise


class KernelDensityEstimator:
    """The KernelDensityEstimator model.

    Places a kernel on each training point and evaluates the resulting
    density at query points.
    """

    def __init__(self, train_Xs, std=0.1, kernel=None):
        """Initializes a new KernelDensityEstimator.

        Args:
            train_Xs: The "training" data to use when estimating probabilities.
                Must have exactly two axes (num_samples, num_features).
            std: Bandwidth of the default ``GaussianKernel``. Ignored when
                ``kernel`` is given.
            kernel: Optional kernel to place on each of the train_Xs. Any
                callable ``kernel(test_Xs, train_Xs)`` is accepted; defaults
                to ``GaussianKernel(bandwidth=std)``.

        Raises:
            ValueError: If train_Xs does not have exactly two axes.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(train_Xs.shape) != 2:
            raise ValueError(
                "train_Xs must have exactly two axes (num_samples, num_features), "
                f"got shape {tuple(train_Xs.shape)}."
            )
        self.kernel = GaussianKernel(bandwidth=std) if kernel is None else kernel
        self.train_Xs = train_Xs

    def __call__(self, x):
        """Evaluates the kernel at each point of x against the training data.

        Args:
            x: Query points; axis 0 indexes the points.

        Returns:
            Whatever ``self.kernel(x, self.train_Xs)`` returns (log densities
            for the kernels defined in this module).
        """
        return self.kernel(x, self.train_Xs)
